/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <thrift/thrift-config.h>

#include <string>

#include <boost/lexical_cast.hpp>
#include <thrift/concurrency/Mutex.h>
#include <thrift/concurrency/IOService.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/TSSLContext.h>

using namespace std;
using namespace apache::thrift::concurrency;
using namespace boost::asio::ip;

namespace apache {
namespace thrift {
namespace transport {

// NOTE(review): neither of these file-local functions is defined or called in
// this file (verify_certificate() calls TSSLContext::matchName instead), so
// these declarations look like leftovers; compilers may warn that a 'static'
// function is declared but never defined.
static bool matchName(const char* host, const char* pattern, int size);
static char uppercase(char c);


// TSSLSocket implementation
/**
 * Creates an SSL socket with no target endpoint (host "", port 0). This is the
 * constructor used by TSSLSocketFactory::createSocket(). A client must set the
 * endpoint with setHost()/setPort() before calling open().
 *
 * @param ctx SSL context. It supplies the OpenSSL context for the stream and
 *            the verification settings that open() consults.
 */
TSSLSocket::TSSLSocket(const boost::shared_ptr<TSSLContext>& ctx)
  : server_(false)
  , sslctx_(ctx)
  // Read the OpenSSL context through the parameter, not sslctx_. That way this
  // initializer does not depend on sslctx_ being declared before sslsocket_.
  , sslsocket_(IOService::getDefaultIOService(), ctx->getSSLCtx())
  , deadline_(IOService::getDefaultIOService())
  , host_("")
  , port_(0)
  , connTimeout_(0)
  , sendTimeout_(0)
  , recvTimeout_(0)
  , keepAlive_(false)
  , lingerOn_(1)
  , lingerVal_(0)
  , noDelay_(1)
  , maxRecvRetries_(5)
{
  // NOTE(review): peerPort_ is not in this initializer list. Unless the header
  // gives it a default, getPeerPort() can return an indeterminate value when
  // the remote endpoint lookup fails. Confirm.
}

/**
 * Creates a client-side SSL socket aimed at host:port. Nothing is connected
 * until open() is called.
 *
 * @param ctx  SSL context. It supplies the OpenSSL context for the stream and
 *             the verification settings that open() consults.
 * @param host Remote host name or address, resolved in open().
 * @param port Remote port. open() rejects values outside [0, 65535].
 */
TSSLSocket::TSSLSocket(const boost::shared_ptr<TSSLContext>& ctx, const string& host, int port)
  : server_(false)
  , sslctx_(ctx)
  // Read the OpenSSL context through the parameter, not sslctx_. That way this
  // initializer does not depend on sslctx_ being declared before sslsocket_.
  , sslsocket_(IOService::getDefaultIOService(), ctx->getSSLCtx())
  , deadline_(IOService::getDefaultIOService())
  , host_(host)
  , port_(port)
  , connTimeout_(0)
  , sendTimeout_(0)
  , recvTimeout_(0)
  , keepAlive_(false)
  , lingerOn_(1)
  , lingerVal_(0)
  , noDelay_(1)
  , maxRecvRetries_(5)
{
  // NOTE(review): peerPort_ is not in this initializer list. Unless the header
  // gives it a default, getPeerPort() can return an indeterminate value when
  // the remote endpoint lookup fails. Confirm.
}

/**
 * Destructor: hands the connection to close(). close() shuts down the TLS
 * session if the underlying socket is still open.
 */
TSSLSocket::~TSSLSocket()
{
  this->close();
}

/** @return true while the TCP socket underneath the SSL stream is open. */
bool TSSLSocket::isOpen()
{
  const bool tcpOpen = sslsocket_.lowest_layer().is_open();
  return tcpOpen;
}

/**
 * Reports whether a read may be attempted. The connection is not probed; the
 * result simply mirrors isOpen().
 */
bool TSSLSocket::peek()
{
  return isOpen();
}

bool TSSLSocket::verify_certificate(bool preverified, boost::asio::ssl::verify_context& ctx)
{
  if (!sslctx_->checkPeerName()) {
    return true;
  }

  X509* cert = X509_STORE_CTX_get_current_cert(ctx.native_handle());
  if (cert == NULL) {
    TLogging::log_err("verifyCertificate: certificate not present");
    return false;
  }
  
  bool verified = false;
  std::string host;
  if (!sslctx_->peerFixedName().empty()) {
    host = sslctx_->peerFixedName();
  }

  try {
    STACK_OF(GENERAL_NAME)* alternatives = NULL;
    // only consider alternatives if we're not matching against a fixed name
    if (sslctx_->peerFixedName().empty()) {
      alternatives = (STACK_OF(GENERAL_NAME)*)X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
    }
    if (alternatives != NULL) {
      // verify subjectAlternativeName
      const int count = sk_GENERAL_NAME_num(alternatives);
      for (int i = 0; !verified && i < count; i++) {
        const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
        if (name == NULL) {
          continue;
        }
        char* data = (char*)ASN1_STRING_data(name->d.ia5);
        int length = ASN1_STRING_length(name->d.ia5);
        switch (name->type) {
          case GEN_DNS: {
            if (host.empty()) {
              host = (server() ? getPeerHost() : getHost());
            }
            if (TSSLContext::matchName(host.c_str(), data, length)) {
              verified = true;
            }
            break;
          }
          case GEN_IPADD: {
            boost::system::error_code ec;
            boost::asio::ip::tcp::endpoint endpoint = sslsocket_.lowest_layer().remote_endpoint(ec);
            if (!ec) {
              boost::asio::ip::address remaddr = endpoint.address();
              if (remaddr.is_v4() && length == sizeof(boost::asio::ip::address_v4::bytes_type)) {
                boost::asio::ip::address_v4::bytes_type addrbytes = remaddr.to_v4().to_bytes();
                bool eq = true;
                for (int i = 0; i < length; ++i) {
                  if (addrbytes[i] != (unsigned char)data[i]) {
                    eq = false;
                  }
                }
                verified = eq;
              }
            }
            break;
          }
        }
      }
    }

    if (verified) {
      return true;
    }

    // verify commonName
    X509_NAME* name = X509_get_subject_name(cert);
    if (name != NULL) {
      X509_NAME_ENTRY* entry;
      unsigned char* utf8;
      int last = -1;
      while (!verified) {
        last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
        if (last == -1) {
          break;
        }
        entry = X509_NAME_get_entry(name, last);
        if (entry == NULL) {
          continue;
        }
        ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
        int size = ASN1_STRING_to_UTF8(&utf8, common);
        string fart((char*)utf8, size);
        if (host.empty()) {
          host = (server() ? getPeerHost() : getHost());
        }
        if (TSSLContext::matchName(host.c_str(), (char*)utf8, size)) {
          verified = true;
        }
        OPENSSL_free(utf8);
      }
    }
  }
  catch (const exception& e) {
    TLogging::log_err("exception occured while verify: %s", e.what());
    return false;
  }

  return verified;
}

/**
 * Handler shared by boost::asio::async_connect() and the connect deadline
 * timer. Both paths publish into one shared error code guarded by `mutx`,
 * then signal `cond`. Callers seed the error code with would_block and wait
 * until it changes (see openConnection()).
 */
struct ssl_connect_handler
{
  ssl_connect_handler(const boost::shared_ptr<boost::system::error_code>& ec,
                      const boost::shared_ptr<boost::mutex>& mux,
                      const boost::shared_ptr<boost::condition_variable>& cnd)
    : error(ec)
    , mutx(mux)
    , cond(cnd)
  {

  }

  /** async_connect completion: record the outcome (success or failure) and wake the waiter. */
  void operator () (const boost::system::error_code& ec, tcp::resolver::iterator /*next*/)
  {
    boost::unique_lock<boost::mutex> lock(*mutx);
    *error = ec;
    lock.unlock();
    cond->notify_one();
  }

  /**
   * Deadline timer completion.
   *
   * A non-zero `ec` (e.g. operation_aborted after cancel() or a new
   * expires_from_now()) is ignored. A real expiry records timed_out only while
   * the operation is still pending. Otherwise a timer firing just after the
   * operation finished would overwrite the genuine result.
   */
  void operator () (const boost::system::error_code& ec)
  {
    if (ec) {
      return;
    }
    boost::unique_lock<boost::mutex> lock(*mutx);
    if (*error == boost::asio::error::would_block) {
      *error = boost::asio::error::timed_out;
      lock.unlock();
      cond->notify_one();
    }
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<boost::mutex> mutx;
  boost::shared_ptr<boost::condition_variable> cond;
};

/**
 * Completion handler for ssl::stream::async_handshake(). It stores the
 * handshake result in the shared error code and wakes whoever waits on `cond`.
 *
 * NOTE: it records whatever code it receives. It therefore cannot tell a
 * handshake completion from a deadline-timer expiry, since both deliver a
 * single error_code and an expired timer delivers success.
 */
struct ssl_handshake_handler
{
  ssl_handshake_handler(const boost::shared_ptr<boost::system::error_code>& ec,
                        const boost::shared_ptr<boost::mutex>& mux,
                        const boost::shared_ptr<boost::condition_variable>& cnd)
    : error(ec), mutx(mux), cond(cnd)
  {
  }

  void operator () (const boost::system::error_code& ec)
  {
    {
      // Publish the result under the mutex; notify after releasing it.
      boost::mutex::scoped_lock guard(*mutx);
      *error = ec;
    }
    cond->notify_one();
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<boost::mutex> mutx;
  boost::shared_ptr<boost::condition_variable> cond;
};

/**
 * Opens the connection.
 *
 * Server side (server_ set): runs the server TLS handshake on the already
 * accepted socket.
 *
 * Client side: refuses an already-open socket, validates port_, then connects
 * and handshakes via openConnection().
 *
 * In both roles, when the context's verification mode includes verify_peer,
 * verify_certificate() is installed as the verify callback first.
 *
 * @throws TTransportException BAD_ARGS if a client socket is already open;
 *         NOT_OPEN for an invalid port or a failed connect/handshake.
 */
void TSSLSocket::open()
{
  if (!server_ && isOpen()) {
    throw TTransportException(TTransportException::BAD_ARGS);
  }

  const bool verifyPeer = (sslctx_->getVerificationMode() & boost::asio::ssl::verify_peer) != 0;
  if (verifyPeer) {
    sslsocket_.set_verify_callback(boost::bind(&TSSLSocket::verify_certificate, this, _1, _2));
  }

  if (server_) {
    handshake(boost::asio::ssl::stream_base::server);
    return;
  }

  // Validate port number
  const bool portInRange = port_ >= 0 && port_ <= 0xFFFF;
  if (!portInRange) {
    throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid");
  }
  openConnection();
}

void TSSLSocket::openConnection()
{
  if (sendTimeout_ > 0) {
    setSendTimeout(sendTimeout_);
  }

  if (recvTimeout_ > 0) {
    setRecvTimeout(recvTimeout_);
  }

  if (keepAlive_) {
    setKeepAlive(keepAlive_);
  }

  setLinger(lingerOn_, lingerVal_);

  setNoDelay(noDelay_);


  tcp::resolver::query query(host_, apache::thrift::to_string(port_));
  tcp::resolver::iterator iter = tcp::resolver(IOService::getDefaultIOService()).resolve(query);

  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  ssl_connect_handler conn_handler(ec, mux, cond);
  boost::asio::async_connect(sslsocket_.lowest_layer(), iter, conn_handler);

  if (connTimeout_ > 0) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(connTimeout_));
    deadline_.async_wait(conn_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", ec->value());
  }
  else {
    handshake(boost::asio::ssl::stream_base::client);
  }
 
}

/**
 * Runs the TLS handshake on the already-connected socket. It blocks until the
 * handshake finishes or, when connTimeout_ > 0, until that many milliseconds
 * have passed.
 *
 * @param shaketype Role to play (stream_base::client or stream_base::server).
 * @throws TTransportException NOT_OPEN if the handshake fails or times out.
 */
void TSSLSocket::handshake(boost::asio::ssl::stream_base::handshake_type shaketype)
{
  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  sslsocket_.async_handshake(shaketype, ssl_handshake_handler(ec, mux, cond));

  if (connTimeout_ > 0) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(connTimeout_));
    // The handshake handler must not serve as the timer handler: an expired
    // timer completes with success, which it would record as a finished
    // handshake. ssl_connect_handler's single-argument overload turns a real
    // expiry into timed_out and ignores cancellation.
    deadline_.async_wait(ssl_connect_handler(ec, mux, cond));
  }

  boost::system::error_code result;
  {
    boost::unique_lock<boost::mutex> lock(*mux);
    while (*ec == boost::asio::error::would_block) {
      cond->wait(lock);
    }
    result = *ec;
  }

  boost::system::error_code ignored;
  deadline_.cancel(ignored);

  if (result) {
    // log_err is printf-style: never pass runtime text as the format string.
    TLogging::log_err("%s", result.message().c_str());
    // NOTE(review): on timeout the async_handshake is still outstanding here.
    throw TTransportException(TTransportException::NOT_OPEN, "handshake() failed", result.value());
  }
}

/**
 * Closes the connection if the underlying TCP socket is open. It first makes
 * a best-effort TLS shutdown (close_notify), then closes the TCP socket.
 * Errors from both steps are ignored.
 */
void TSSLSocket::close()
{
  if (!sslsocket_.lowest_layer().is_open()) {
    return;
  }
  boost::system::error_code ec;
  sslsocket_.shutdown(ec);
  // shutdown() only ends the TLS session. The descriptor stays open, and
  // isOpen() keeps returning true, until the TCP socket is closed explicitly.
  sslsocket_.lowest_layer().close(ec);
}

/**
 * Handler shared by an asynchronous read/write and its deadline timer. Results
 * are published into shared state guarded by `mutx`, then `cond` is signalled:
 *   - `error` is seeded with would_block by the caller (see read()/write());
 *   - `bytes_transferred` accumulates the byte count of successful completions.
 */
struct ssl_iohandler
{
  ssl_iohandler(const boost::shared_ptr<boost::system::error_code>& ec,
                const boost::shared_ptr<std::size_t>& length,
                const boost::shared_ptr<boost::mutex>& mux,
                const boost::shared_ptr<boost::condition_variable>& cnd)
    : error(ec)
    , bytes_transferred(length)
    , mutx(mux)
    , cond(cnd)
  {

  }

  /** I/O completion: record the outcome and, on success, the bytes moved. */
  void operator() (const boost::system::error_code& ec, std::size_t length)
  {
    boost::unique_lock<boost::mutex> lock(*mutx);
    *error = ec;
    if (!ec) {
      *bytes_transferred += length;
    }
    lock.unlock();
    cond->notify_one();
  }

  /**
   * Deadline timer completion.
   *
   * A non-zero `ec` (cancelled or re-armed timer) is ignored. An expiry records
   * timed_out only while the I/O is still pending. Otherwise a timer firing just
   * after the I/O completed would discard data that was already transferred.
   */
  void operator() (const boost::system::error_code& ec)
  {
    if (ec) {
      return;
    }
    boost::unique_lock<boost::mutex> lock(*mutx);
    if (*error == boost::asio::error::would_block) {
      *error = boost::asio::error::timed_out;
      lock.unlock();
      cond->notify_one();
    }
  }

  boost::shared_ptr<boost::system::error_code> error;
  boost::shared_ptr<std::size_t> bytes_transferred;
  boost::shared_ptr<boost::mutex> mutx;
  boost::shared_ptr<boost::condition_variable> cond;
};

uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len)
{
  if (isOpen() == false) {
    throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
  }

  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  ssl_iohandler read_handler(ec, length, mux, cond);

  sslsocket_.async_read_some(boost::asio::buffer(buf, len), read_handler);

  if (recvTimeout_) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(recvTimeout_));
    deadline_.async_wait(read_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    if (*length > 0) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
  }
  else {
    deadline_.cancel();
  }

  // Pack data into string
  return (uint32_t)(*length);
}

uint32_t TSSLSocket::readAll(uint8_t* buf, uint32_t len)
{
  if (isOpen() == false) {
    throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
  }

  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  ssl_iohandler read_handler(ec, length, mux, cond);

  boost::asio::async_read(sslsocket_, boost::asio::buffer(buf, len), read_handler);

  if (recvTimeout_) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(recvTimeout_));
    deadline_.async_wait(read_handler);
  }

  boost::unique_lock<boost::mutex> lock(*mux);
  while (*ec == boost::asio::error::would_block) {
    if (*length == len) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
  }
  else {
    deadline_.cancel();
  }

  // Pack data into string
  return (uint32_t)(*length);
}

void TSSLSocket::write(const uint8_t* buf, uint32_t len)
{
  boost::shared_ptr<boost::system::error_code> ec(new boost::system::error_code(boost::asio::error::would_block));
  boost::shared_ptr<boost::mutex> mux(new boost::mutex);
  boost::shared_ptr<boost::condition_variable> cond(new boost::condition_variable);
  boost::shared_ptr<size_t> length(new size_t(0));
  ssl_iohandler write_handler(ec, length, mux, cond);

  boost::asio::async_write(sslsocket_, boost::asio::buffer(buf, len), write_handler);

  if (sendTimeout_ > 0) {
    deadline_.expires_from_now(boost::posix_time::milliseconds(sendTimeout_));
    deadline_.async_wait(write_handler);
  }

  while (*ec == boost::asio::error::would_block)
  {
    boost::unique_lock<boost::mutex> lock(*mux);
    if (*length == len) {
      break;
    }
    cond->wait(lock);
  }

  if (*ec) {
    throw TTransportException(TTransportException::TIMED_OUT, "send timeout expired");
  }
  else {
    deadline_.cancel();
  }
}

std::string TSSLSocket::getHost()
{
  return host_;
}

int TSSLSocket::getPort()
{
  return port_;
}

void TSSLSocket::setHost(string host)
{
  host_ = host;
}

void TSSLSocket::setPort(int port)
{
  port_ = port;
}

/**
 * Records the SO_LINGER setting. It is also applied immediately when the
 * socket is open.
 *
 * @param on     Whether lingering on close is enabled.
 * @param linger Linger time passed to socket_base::linger.
 */
void TSSLSocket::setLinger(bool on, int linger)
{
  lingerOn_ = on;
  lingerVal_ = linger;
  if (isOpen()) {
    sslsocket_.lowest_layer().set_option(boost::asio::socket_base::linger(on, linger));
  }
}

/**
 * Records the TCP_NODELAY setting. It is also applied immediately when the
 * socket is open.
 */
void TSSLSocket::setNoDelay(bool noDelay)
{
  noDelay_ = noDelay;
  if (sslsocket_.lowest_layer().is_open()) {
    sslsocket_.lowest_layer().set_option(boost::asio::ip::tcp::no_delay(noDelay));
  }
}

void TSSLSocket::setConnTimeout(int ms)
{
  connTimeout_ = ms;
}


void TSSLSocket::setRecvTimeout(int ms)
{
  recvTimeout_ = ms;
}

void TSSLSocket::setSendTimeout(int ms)
{
  sendTimeout_ = ms;
}

/**
 * Records the SO_KEEPALIVE setting. It is also applied immediately when the
 * socket is open.
 */
void TSSLSocket::setKeepAlive(bool keepAlive)
{
  keepAlive_ = keepAlive;
  if (sslsocket_.lowest_layer().is_open()) {
    sslsocket_.lowest_layer().set_option(boost::asio::socket_base::keep_alive(keepAlive));
  }
}

void TSSLSocket::setMaxRecvRetries(int maxRecvRetries)
{
  maxRecvRetries_ = maxRecvRetries;
}

/**
 * Describes the connection as "<Host: H Port: P>".
 *
 * The configured host_/port_ are used when both are set. Otherwise the peer
 * address and port are reported.
 */
string TSSLSocket::getSocketInfo()
{
  std::ostringstream info;
  const bool haveConfiguredEndpoint = !host_.empty() && port_ != 0;
  if (haveConfiguredEndpoint) {
    info << "<Host: " << host_ << " Port: " << port_ << ">";
  } else {
    info << "<Host: " << getPeerAddress();
    info << " Port: " << getPeerPort() << ">";
  }
  return info.str();
}

std::string TSSLSocket::getPeerHost()
{
  if (peerHost_.empty()) {

    if (sslsocket_.lowest_layer().is_open() == false) {
      return host_;
    }
    boost::system::error_code ec;
    boost::asio::ip::tcp::endpoint endpoint = sslsocket_.lowest_layer().remote_endpoint(ec);
    if (!ec) {
      peerHost_ = endpoint.address().to_string();
      peerPort_ = endpoint.port();
      peerAddress_ = peerHost_;
    }

  }
  return peerHost_;
}

std::string TSSLSocket::getPeerAddress()
{
  if (peerHost_.empty()) {
    if (sslsocket_.lowest_layer().is_open() == false) {
      return peerAddress_;
    }
    boost::system::error_code ec;
    boost::asio::ip::tcp::endpoint endpoint = sslsocket_.lowest_layer().remote_endpoint(ec);
    if (!ec) {
      peerHost_ = endpoint.address().to_string();
      peerPort_ = endpoint.port();
      peerAddress_ = peerHost_;
    }
  }
  return peerAddress_;
}

/** @return The peer's port, after making sure the peer cache has been populated. */
int TSSLSocket::getPeerPort()
{
  (void)getPeerAddress(); // called only for its cache-filling side effect
  return peerPort_;
}

const std::string TSSLSocket::getOrigin()
{
  std::ostringstream oss;
  oss << getPeerHost() << ":" << getPeerPort();
  return oss.str();
}

// TSSLSocketFactory implementation

/**
 * Creates a factory whose sockets all share the SSL context `ctx`.
 * The factory's server_ flag starts as false.
 */
TSSLSocketFactory::TSSLSocketFactory(const boost::shared_ptr<TSSLContext>& ctx)
  : ctx_(ctx)
  , server_(false)
{
  
}

/** Destructor: no explicit cleanup is performed. */
TSSLSocketFactory::~TSSLSocketFactory()
{
 
}

/**
 * Creates a socket with no target endpoint that shares this factory's SSL
 * context. The new socket inherits the factory's server() flag.
 */
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket()
{
  typedef boost::shared_ptr<TSSLSocket> SocketPtr;
  SocketPtr created(new TSSLSocket(ctx_));
  created->server(server());
  return created;
}

/**
 * Creates a socket aimed at host:port that shares this factory's SSL context.
 * The new socket inherits the factory's server() flag.
 */
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port)
{
  typedef boost::shared_ptr<TSSLSocket> SocketPtr;
  SocketPtr created(new TSSLSocket(ctx_, host, port));
  created->server(server());
  return created;
}

}
}
}
